# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""RetinaNet."""
import collections
from typing import Any, Mapping, List, Optional, Union, Sequence

import tensorflow as tf, tf_keras

from official.vision.ops import anchor


@tf_keras.utils.register_keras_serializable(package='Vision')
class RetinaNetModel(tf_keras.Model):
  """The RetinaNet model class.

  Composes a backbone, an optional decoder, a dense prediction head and a
  detection generator. In training mode the raw head outputs are returned; in
  inference mode they are additionally post-processed by the detection
  generator using multilevel anchor boxes.
  """

  def __init__(self,
               backbone: tf_keras.Model,
               decoder: Optional[tf_keras.Model],
               head: tf_keras.layers.Layer,
               detection_generator: tf_keras.layers.Layer,
               anchor_boxes: Mapping[str, tf.Tensor] | None = None,
               min_level: Optional[int] = None,
               max_level: Optional[int] = None,
               num_scales: Optional[int] = None,
               aspect_ratios: Optional[List[float]] = None,
               anchor_size: Optional[float] = None,
               **kwargs):
    """Detection initialization function.

    Args:
      backbone: `tf_keras.Model` a backbone network.
      decoder: `tf_keras.Model` a decoder network, or None to feed the
        backbone features directly to the head.
      head: `RetinaNetHead`, the RetinaNet head. When called on the features it
        must return a `(scores, boxes, attributes)` triple.
      detection_generator: the detection generator.
      anchor_boxes: a dict of tensors which includes multilevel anchors.
        - key: `str`, the level of the multilevel predictions.
        - values: `Tensor`, the anchor coordinates of a particular feature
            level, whose shape is [height_l, width_l, 4 *
            num_anchors_per_location_l].
        If provided, these anchors will be used for inference (training=False)
        and take precedence over anchors passed to `call`. Note that they are
        not stored in the config returned by `get_config`, so they do not
        survive a `get_config`/`from_config` round trip.
      min_level: Minimum level in output feature maps.
      max_level: Maximum level in output feature maps.
      num_scales: A number representing intermediate scales added
        on each level. For instance, num_scales=2 adds one additional
        intermediate anchor scales [2^0, 2^0.5] on each level.
      aspect_ratios: A list representing the aspect ratio
        anchors added on each level. The number indicates the ratio of width to
        height. For instance, aspect_ratios=[1.0, 2.0, 0.5] adds three anchors
        on each scale level.
      anchor_size: A number representing the scale of size of the base
        anchor to the feature stride 2^level.
      **kwargs: keyword arguments to be passed.

    `min_level`, `max_level`, `num_scales`, `aspect_ratios` and `anchor_size`
    are only used to generate anchors at inference time when no anchor boxes
    are supplied either here or to `call`.
    """
    super().__init__(**kwargs)
    self._config_dict = {
        'backbone': backbone,
        'decoder': decoder,
        'head': head,
        'detection_generator': detection_generator,
        'min_level': min_level,
        'max_level': max_level,
        'num_scales': num_scales,
        'aspect_ratios': aspect_ratios,
        'anchor_size': anchor_size,
    }
    self._backbone = backbone
    self._decoder = decoder
    self._head = head
    self._detection_generator = detection_generator
    self._anchor_boxes = anchor_boxes

  def call(
      self,
      images: Union[tf.Tensor, Sequence[tf.Tensor]],
      image_shape: Optional[tf.Tensor] = None,
      anchor_boxes: Mapping[str, tf.Tensor] | None = None,
      output_intermediate_features: bool = False,
      training: Optional[bool] = None,
  ) -> Mapping[str, tf.Tensor]:
    """Forward pass of the RetinaNet model.

    Args:
      images: `Tensor` or a sequence of `Tensor`, the input batched images to
        the backbone network, whose shape(s) is [batch, height, width, 3]. If it
        is a sequence of `Tensor`, we will assume the anchors are generated
        based on the shape of the first image(s).
      image_shape: `Tensor`, the actual shape of the input images, whose shape
        is [batch, 2] where the last dimension is [height, width]. Note that
        this is the actual image shape excluding paddings. For example, images
        in the batch may be resized into different shapes before padding to the
        fixed size.
      anchor_boxes: the anchor boxes to use for inference (training=False) if
        not provided in the init. They are passed to the detection generator
        as-is (no batch tiling is applied).
      output_intermediate_features: `bool` indicating whether to return the
        intermediate feature maps generated by backbone and decoder.
      training: `bool`, indicating whether it is in training mode.

    Returns:
      scores: a dict of tensors which includes scores of the predictions.
        - key: `str`, the level of the multilevel predictions.
        - values: `Tensor`, the box scores predicted from a particular feature
            level, whose shape is
            [batch, height_l, width_l, num_classes * num_anchors_per_location].
      boxes: a dict of tensors which includes coordinates of the predictions.
        - key: `str`, the level of the multilevel predictions.
        - values: `Tensor`, the box coordinates predicted from a particular
            feature level, whose shape is
            [batch, height_l, width_l, 4 * num_anchors_per_location].
      attributes: a dict of (attribute_name, attribute_predictions). Each
        attribute prediction is a dict that includes:
        - key: `str`, the level of the multilevel predictions.
        - values: `Tensor`, the attribute predictions from a particular
            feature level, whose shape is
            [batch, height_l, width_l, att_size * num_anchors_per_location].

    Raises:
      ValueError: if anchors have to be generated and `images` is neither a
        `tf.Tensor` nor a sequence of `tf.Tensor`.
    """
    outputs = {}
    # Feature extraction.
    features = self.backbone(images)
    if output_intermediate_features:
      outputs.update(
          {'backbone_{}'.format(k): v for k, v in features.items()})
    if self.decoder is not None:
      features = self.decoder(features)
    # Without a decoder, the `decoder_*` entries mirror the backbone features.
    if output_intermediate_features:
      outputs.update(
          {'decoder_{}'.format(k): v for k, v in features.items()})

    # Dense prediction. `raw_attributes` can be empty.
    raw_scores, raw_boxes, raw_attributes = self.head(features)
    outputs.update({
        'cls_outputs': raw_scores,
        'box_outputs': raw_boxes,
    })

    if training:
      if raw_attributes:
        outputs['attribute_outputs'] = raw_attributes
      return outputs

    # Inference: decode (and optionally NMS) the raw predictions.
    anchor_boxes = self._resolve_anchor_boxes(images, raw_boxes, anchor_boxes)
    final_results = self.detection_generator(raw_boxes, raw_scores,
                                             anchor_boxes, image_shape,
                                             raw_attributes)

    generator_config = self.detection_generator.get_config()
    if generator_config['apply_nms']:
      outputs.update({
          'detection_boxes': final_results['detection_boxes'],
          'detection_scores': final_results['detection_scores'],
          'detection_classes': final_results['detection_classes'],
          'num_detections': final_results['num_detections'],
      })
      # Users can choose to include the decoded results (boxes before NMS) in
      # the output tensor dict even if `apply_nms` is set to `True`.
      if generator_config['return_decoded']:
        self._add_decoded_results(outputs, final_results)
    else:
      self._add_decoded_results(outputs, final_results)

    if raw_attributes:
      outputs['attribute_outputs'] = raw_attributes
      # The detection generator may not emit post-NMS attributes (e.g. when
      # NMS is disabled), so only forward them when they are present instead
      # of failing with a KeyError.
      if 'detection_attributes' in final_results:
        outputs['detection_attributes'] = final_results['detection_attributes']
    return outputs

  @staticmethod
  def _add_decoded_results(outputs, final_results) -> None:
    """Copies the decoded (pre-NMS) boxes, scores and attributes to `outputs`."""
    outputs.update({
        'decoded_boxes': final_results['decoded_boxes'],
        'decoded_box_scores': final_results['decoded_box_scores'],
    })
    if final_results.get('decoded_box_attributes') is not None:
      outputs['decoded_box_attributes'] = final_results[
          'decoded_box_attributes'
      ]

  @staticmethod
  def _tile_anchors(anchors, batch_size):
    """Tiles every level of `anchors` along a new leading batch axis."""
    tiled = collections.OrderedDict()
    for level, boxes in anchors.items():
      tiled[level] = tf.tile(
          tf.expand_dims(boxes, axis=0), [batch_size, 1, 1, 1])
    return tiled

  def _resolve_anchor_boxes(
      self,
      images: Union[tf.Tensor, Sequence[tf.Tensor]],
      raw_boxes: Mapping[str, tf.Tensor],
      anchor_boxes: Mapping[str, tf.Tensor] | None,
  ) -> Mapping[str, tf.Tensor]:
    """Returns the multilevel anchor boxes to use at inference time.

    Precedence: anchors given to the constructor (tiled to the batch size),
    then `anchor_boxes` given to `call` (used as-is), then anchors generated
    from the static spatial size of the (first) input images and tiled to the
    batch size.

    Raises:
      ValueError: if anchors must be generated and `images` is neither a
        `tf.Tensor` nor a sequence of `tf.Tensor`.
    """
    if self._anchor_boxes is not None:
      # Every level shares the batch dimension, so any level can supply it.
      # This avoids indexing by `min_level`, which may be None when anchors
      # are provided at construction time.
      batch_size = tf.shape(next(iter(raw_boxes.values())))[0]
      return self._tile_anchors(self._anchor_boxes, batch_size)

    if anchor_boxes is not None:
      return anchor_boxes

    # Generate anchor boxes for this batch since none were provided.
    if isinstance(images, Sequence):
      primary_images = images[0]
    elif isinstance(images, tf.Tensor):
      primary_images = images
    else:
      raise ValueError(
          'Input should be a tf.Tensor or a sequence of tf.Tensor, not {}.'
          .format(type(images)))

    # Anchors are built from the static (graph-time) height and width.
    _, image_height, image_width, _ = primary_images.get_shape().as_list()
    multilevel_anchors = anchor.Anchor(
        min_level=self._config_dict['min_level'],
        max_level=self._config_dict['max_level'],
        num_scales=self._config_dict['num_scales'],
        aspect_ratios=self._config_dict['aspect_ratios'],
        anchor_size=self._config_dict['anchor_size'],
        image_size=(image_height, image_width)).multilevel_boxes
    return self._tile_anchors(multilevel_anchors,
                              tf.shape(primary_images)[0])

  @property
  def checkpoint_items(
      self) -> Mapping[str, Union[tf_keras.Model, tf_keras.layers.Layer]]:
    """Returns a dictionary of items to be additionally checkpointed."""
    items = dict(backbone=self.backbone, head=self.head)
    if self.decoder is not None:
      items.update(decoder=self.decoder)

    return items

  @property
  def backbone(self) -> tf_keras.Model:
    return self._backbone

  @property
  def decoder(self) -> Optional[tf_keras.Model]:
    return self._decoder

  @property
  def head(self) -> tf_keras.layers.Layer:
    return self._head

  @property
  def detection_generator(self) -> tf_keras.layers.Layer:
    return self._detection_generator

  @property
  def anchor_boxes(self) -> Mapping[str, tf.Tensor] | None:
    return self._anchor_boxes

  def get_config(self) -> Mapping[str, Any]:
    """Returns a copy of the constructor config (excluding `anchor_boxes`).

    A copy is returned so that callers editing the config (e.g. before
    `from_config`) cannot silently change this model's anchor parameters.
    """
    return dict(self._config_dict)

  @classmethod
  def from_config(cls, config):
    return cls(**config)
